{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iteration: 1, error: 0.229, fairness violation: 0.05428400000000001, violated group size: 0.249\n",
      "iteration: 2, error: 0.3645, fairness violation: 0.027142000000000006, violated group size: 0.249\n",
      "iteration: 3, error: 0.4096666666666666, fairness violation: 0.01809466666666667, violated group size: 0.251\n",
      "iteration: 4, error: 0.43225, fairness violation: 0.013571000000000003, violated group size: 0.249\n",
      "iteration: 5, error: 0.44580000000000014, fairness violation: 0.0108568, violated group size: 0.251\n",
      "iteration: 6, error: 0.4548333333333334, fairness violation: 0.009047333333333338, violated group size: 0.251\n",
      "iteration: 7, error: 0.46128571428571435, fairness violation: 0.007754857142857144, violated group size: 0.251\n",
      "iteration: 8, error: 0.466125, fairness violation: 0.006785500000000003, violated group size: 0.251\n",
      "iteration: 9, error: 0.469888888888889, fairness violation: 0.006031555555555558, violated group size: 0.249\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from aif360.algorithms.inprocessing import GerryFairClassifier\n",
    "from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple\n",
    "from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult\n",
    "from sklearn import svm\n",
    "from sklearn import tree\n",
    "from sklearn.kernel_ridge import KernelRidge\n",
    "from sklearn import linear_model\n",
    "from aif360.metrics import BinaryLabelDatasetMetric\n",
    "from IPython.display import Image\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# load data set\n",
    "data_set = load_preproc_data_adult(sub_samp=1000, balance=True)\n",
    "max_iterations = 10\n",
    "C = 100\n",
    "print_flag = True\n",
    "gamma = .005\n",
    "\n",
    "fair_model = GerryFairClassifier(C=C, printflag=print_flag, gamma=gamma, fairness_def='FP',\n",
    "             max_iters=max_iterations, heatmapflag=False)\n",
    "# fit method\n",
    "fair_model.fit(data_set, early_termination=True)\n",
    "\n",
    "# predict method. If threshold in (0, 1) produces binary predictions\n",
    "dataset_yhat = fair_model.predict(data_set, threshold=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0060315555555555565\n"
     ]
    }
   ],
   "source": [
    "# auditing \n",
    "\n",
    "gerry_metric = BinaryLabelDatasetMetric(data_set)\n",
    "gamma_disparity = gerry_metric.rich_subgroup(array_to_tuple(dataset_yhat.labels), 'FP')\n",
    "print(gamma_disparity)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Curr Predictor: Linear\n",
      "Curr Predictor: SVR\n",
      "Curr Predictor: Tree\n",
      "Curr Predictor: Kernel\n"
     ]
    }
   ],
   "source": [
    "# set to 10 iterations for fast running of notebook - set >= 1000 when running real experiments\n",
    "# tests learning with different hypothesis classes\n",
    "pareto_iters = 10\n",
    "def multiple_classifiers_pareto(dataset, gamma_list=[0.002, 0.005, 0.01], save_results=False, iters=pareto_iters):\n",
    "\n",
    "    ln_predictor = linear_model.LinearRegression()\n",
    "    svm_predictor = svm.LinearSVR()\n",
    "    tree_predictor = tree.DecisionTreeRegressor(max_depth=3)\n",
    "    kernel_predictor = KernelRidge(alpha=1.0, gamma=1.0, kernel='rbf')\n",
    "    predictor_dict = {'Linear': {'predictor': ln_predictor, 'iters': iters},\n",
    "                      'SVR': {'predictor': svm_predictor, 'iters': iters},\n",
    "                      'Tree': {'predictor': tree_predictor, 'iters': iters},\n",
    "                      'Kernel': {'predictor': kernel_predictor, 'iters': iters}}\n",
    "\n",
    "    results_dict = {}\n",
    "\n",
    "    for pred in predictor_dict:\n",
    "        print('Curr Predictor: {}'.format(pred))\n",
    "        predictor = predictor_dict[pred]['predictor']\n",
    "        max_iters = predictor_dict[pred]['iters']\n",
    "        fair_clf = GerryFairClassifier(C=100, printflag=True, gamma=1, predictor=predictor, max_iters=max_iters)\n",
    "        fair_clf.printflag = False\n",
    "        fair_clf.max_iters=max_iters\n",
    "        errors, fp_violations, fn_violations = fair_clf.pareto(dataset, gamma_list)\n",
    "        results_dict[pred] = {'errors': errors, 'fp_violations': fp_violations, 'fn_violations': fn_violations}\n",
    "    if save_results:\n",
    "        pickle.dump(results_dict, open('results_dict_' + str(gamma_list) + '_gammas' + str(gamma_list) + '.pkl', 'wb'))\n",
    "\n",
    "multiple_classifiers_pareto(data_set)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
